qqmm #2789
base: main
Conversation
    bool transpose_;
    };

    class DualQuantizedMatmul : public UnaryPrimitive {
A bit of a nit but I think it makes sense to rename this to QuantizedQuantizedMatmul or QQMatmul to better match the name of the op. Dual is also kind of an overloaded term.
Yes, I agree. I think QQMatmul is better, because then the primitive name and the op name are aligned.
    bool is_equivalent(const Primitive& other) const override;
    std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
    auto state() const {
      return std::make_tuple(group_size_, bits_, mode_);
transpose_ should be part of the state here.
Yeah, this is a bit unclear and probably should be changed. transpose is not a member variable; qqmm is always executed in TN layout (transpose = True). I did it this way because, at the moment, quantization always produces a row-major tensor with the last dimension packed, and TN is the only layout supported for mxfp4 and nvfp4 on B200.
I see it below in the list under private:. Maybe it should be deleted?
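A small sketch of the layout convention described in this thread; the shapes and calls are illustrative assumptions, not taken from the diff:

```python
import mlx.core as mx

# w is quantized along its last (reduction) dimension, so the packed tensor
# stays row-major and the matmul is effectively x @ w.T, i.e. always TN.
M, N, K = 8, 16, 64
x = mx.random.normal((M, K)).astype(mx.bfloat16)   # activations: (M, K)
w = mx.random.normal((N, K)).astype(mx.bfloat16)   # weights: (N, K), K gets packed after quantization
out_ref = x @ w.T                                  # what qqmm computes, shape (M, N)
```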
    ds = mx.grad(gmm)(s, x, wq)

    def test_qqmm(self):
These tests should only be run for now if mx.cuda.is_available().
And in fact I'm not sure what the behavior is on older hardware and CUDA toolkits. Do you know what the minimum requirements there are?
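A minimal sketch of gating the test on CUDA availability as suggested; the test class name and decorator placement are assumptions:

```python
import unittest

import mlx.core as mx

class TestQQMM(unittest.TestCase):
    # Skip on machines without CUDA, since qqmm is only implemented there for now.
    @unittest.skipUnless(mx.cuda.is_available(), "qqmm requires CUDA")
    def test_qqmm(self):
        ...
```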
    std::optional<int> bits_ /* = std::nullopt */,
    const std::string& mode /* = "nvfp4" */,
    StreamOrDevice s /* = {} */) {
      // currently only simetric quantization is supported for qqmm
Suggested change:
    - // currently only simetric quantization is supported for qqmm
    + // currently only symmetric quantization is supported for qqmm
    if (qmode == QuantizationMode::Affine) {
      std::ostringstream msg;
      msg << "[qqmm] Affine quantization is not supported for qqmm.";
      throw std::invalid_argument(msg.str());
    }
It looks like this was already checked above?
    // https://docs.nvidia.com/cutlass/4.2.1/media/docs/cpp/blackwell_functionality.html
    // because w_q should always be quantized along the reduction dimension
    // and we quantize so that the last dim is packed, we assume that the last dim
    // always the reduction dim so the firat argument in cubals column major is
Suggested change:
    - // always the reduction dim so the firat argument in cubals column major is
    + // is always the reduction dim so the first argument in cublas column major is
This comment feels like it belongs in the cublas implementation rather than here.
    auto [w_inner_dims, w_outer_dims] =
        extract_qqmm_dims("qqmm", x, w_q, scales_w, w, group_size, bits);

    // we don't backprope through qunatized w and scales
Suggested change:
    - // we don't backprope through qunatized w and scales
    + // we don't backprop through quantized w and scales
    auto dtype = bfloat16;
    // out dtype can be only bf16 for now
Why this limitation? It looks like the op can output bf16, fp16 or fp32: https://docs.nvidia.com/cuda/cublas/#id103.
The API should infer the output type from x.
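A short illustration of the suggested behavior, where the output dtype follows x instead of being fixed to bf16; the fp_quantize call is an assumed signature, and this shows the proposal rather than what the diff currently does:

```python
import mlx.core as mx

N, K = 64, 128
x = mx.random.normal((8, K)).astype(mx.float16)
w = mx.random.normal((N, K)).astype(mx.float16)
w_q, scales_w = mx.fp_quantize(w, mode="nvfp4")  # assumed signature

out = mx.qqmm(x, w_q, scales_w, mode="nvfp4")
# Under the suggestion, out.dtype would be float16 here (and float32 for fp32 x),
# matching the output types cuBLASLt block-scaled GEMM supports.
```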
    validate_quantized_input(
        tag, w_q, scales_w, "weight matrix", "scales_w", group_size, bits);
I think you can remove the strings here since x is not quantized. The original error message prior to the diff here makes sense.
    array x, // input activations
    array w_q, // quantized weights
    array w_scales,
    std::optional<array> w = std::nullopt, // optional bf16 weights for vjp
I don't really love this API where sometimes it takes a w as input and sometimes not. I wonder if it makes sense to change it to something like:
Suggested change:
    - array x, // input activations
    - array w_q, // quantized weights
    - array w_scales,
    - std::optional<array> w = std::nullopt, // optional bf16 weights for vjp
    + array x, // input activations
    + array w, // possibly quantized weights
    + std::optional<array> scales, // scales for w, if not provided `w` must be unquantized
So then it will quantize on the fly if w is not quantized and otherwise it will just use w as is.
And in order to take a vjp, w has to be provided unquantized.
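A sketch of how the proposed signature could look from Python; the names and the on-the-fly behavior here just illustrate the suggestion and are not part of the diff:

```python
import mlx.core as mx

N, K = 64, 128
x = mx.random.normal((8, K)).astype(mx.bfloat16)
w = mx.random.normal((N, K)).astype(mx.bfloat16)

# Inference path: pass pre-quantized weights together with their scales.
w_q, scales_w = mx.fp_quantize(w, mode="nvfp4")  # assumed signature
out = mx.qqmm(x, w_q, scales_w, mode="nvfp4")

# Training path: pass unquantized w with no scales; the op quantizes on the fly,
# and the vjp can be taken since the full-precision w is available.
def loss(x, w):
    return mx.qqmm(x, w, None, mode="nvfp4").sum()

dx, dw = mx.grad(loss, argnums=[0, 1])(x, w)
```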
        bits_,
        qmode,
        s); // (K, N_packed), scales
    vjps.push_back(qqmm(
A minor problem here is that this function is only once differentiable. I think changing the API as suggested above might fix that. You always quantize the inputs on the fly when you want gradients.
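And a sketch of the differentiability point: with on-the-fly quantization of an unquantized w (per the proposal above) the whole graph stays differentiable, so gradients can be nested; illustrative only, not current behavior:

```python
import mlx.core as mx

x = mx.random.normal((8, 128)).astype(mx.bfloat16)
w = mx.random.normal((64, 128)).astype(mx.bfloat16)

# Hypothetical: scales=None means qqmm quantizes w on the fly.
def f(x, w):
    return mx.qqmm(x, w, None, mode="nvfp4").sum()

dx = mx.grad(f, argnums=0)(x, w)
# Nothing upstream is pre-quantized, so the gradient itself can be differentiated
# again, which the current fixed-w_q vjp only supports once.
ddx = mx.grad(lambda x, w: mx.grad(f, argnums=0)(x, w).sum(), argnums=0)(x, w)
```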
This PR adds a new operation mx.qqmm. The current structure is probably neither optimal nor final.

General comment

- The op is qqmm(quantized weights, bf16 activations).
- For nvfp4, we need to transpose and quantize again along a different dimension.
- For mxfp8, the recommended recipe is to quantize with 1D blocks and keep two views of the weights (normal and transposed).

Therefore, mx.qqmm takes bf16 activations x, quantized weights w_q and their scales, and optionally bf16 weights, plus group_size, mode, and bits.

In the current implementation, it is the user's responsibility to ensure that group_size, bits, and mode match those used to quantize w_q. This is probably not ideal, and we may want to improve this in the future.
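A minimal usage sketch of the API described above; the fp_quantize call and the keyword names are illustrative assumptions, not the exact signatures:

```python
import mlx.core as mx

M, N, K = 32, 64, 128
x = mx.random.normal((M, K)).astype(mx.bfloat16)   # bf16 activations
w = mx.random.normal((N, K)).astype(mx.bfloat16)   # bf16 weights, K is the reduction dim

# Quantize the weights along the last (reduction) dimension.
w_q, scales_w = mx.fp_quantize(w, mode="nvfp4")    # assumed signature

# Inference: only quantized weights and scales are needed.
out = mx.qqmm(x, w_q, scales_w, mode="nvfp4")      # shape (M, N)

# Training: also pass the bf16 weights so the vjp w.r.t. x can be formed.
def loss(x):
    return mx.qqmm(x, w_q, scales_w, w, mode="nvfp4").sum()

dx = mx.grad(loss)(x)
```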
Very important details

- scales are repacked on every call, for both weights and activations. In the future, we probably want to do this packing once, for example as part of fp_quantize.
- Batched qqmm is currently not supported; inputs must be 2D (see the sketch after this list). For now it is implemented this way because:
  - CUBLASLT_BATCH_MODE_STRIDED is not supported for scales.
  - CUBLASLT_BATCH_MODE_POINTER_ARRAY is not supported for arrays with block scaling.
  We almost certainly want to add batching in the future, but for simplicity batch_count = 1 for now.
- qqmm is always executed in TN layout (transpose = True). There are several reasons for this, but mainly we always quantize along the reduction dimension, which currently ends up being the last dimension. I am happy to change this if you think it is useful to support all layouts, for mxfp8 for example. Also, on B200 only the TN layout is supported for nvfp4 and mxfp4.
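Since inputs must currently be 2D, batched activations can be flattened before the call and reshaped back afterwards; this is a workaround sketch (with an assumed fp_quantize signature), not something the PR adds:

```python
import mlx.core as mx

B, M, N, K = 2, 16, 64, 128
x = mx.random.normal((B, M, K)).astype(mx.bfloat16)
w = mx.random.normal((N, K)).astype(mx.bfloat16)
w_q, scales_w = mx.fp_quantize(w, mode="nvfp4")  # assumed signature, as above

# Flatten the batch dims into rows, run the 2D qqmm, then restore the batch shape.
out = mx.qqmm(x.reshape(B * M, K), w_q, scales_w, mode="nvfp4").reshape(B, M, N)
```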
Notes

- cublas_gemm.cpp: I grouped all common cuBLAS-related functions into a separate helper class in cublas_utils.cpp.
- mxfp8 qqmm behaves slightly differently from nvfp4: sometimes, for <<1% of the output elements, the result differs from the dequantized reference by exactly 1 ULP in bf16 (see python/tests/test_quantized.py, line 1027). I do not think this is a bug because:
  - for nvfp4 the output matches exactly for every tested shape.

  Therefore, I attribute this to differences in accumulation on tensor cores or other numerical details we do not control.
What this PR lacks (left out for now because I first want to make sure the rest of the API looks reasonable):

- addmm -- basically c is always nullptr
- nn.QQLinear, analogous to nn.Linear
- to_qqlinear, or a similar method to cast to nn.QQLinear (naming is questionable)

Examples are in python/tests/test_quantized.py.

Happy to iterate and change anything here!